{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "85b9cad9-14cc-4145-a71a-6fdfa7b9b044",
   "metadata": {},
   "source": [
    "# ROUGE-1 (ROUGE for Unigrams)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d152300f-00fd-4f31-bb2c-a776f1822a2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference text that each scorer below compares the candidate against.\n",
    "reference = \"The quick brown fox jumps over the lazy dog\"\n",
    "# Candidate text being scored; all of its words also occur in the reference.\n",
    "candidate = \"The fox jumps over the dog\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3677a7a3-d4d1-49a0-ab90-d02b526cde04",
   "metadata": {},
   "source": [
    "### `rouge` library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6a159771-e2ee-4adc-b53d-d349d69a6a63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to install the `rouge` package into this kernel's environment:\n",
    "# %pip install rouge"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cb45c369-e04f-43ea-a31f-e2380efa3e79",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ROUGE-1 Recall: 0.67\n",
      "ROUGE-1 Precision: 1.00\n",
      "ROUGE-1 F1: 0.80\n"
     ]
    }
   ],
   "source": [
    "from rouge import Rouge\n",
    "\n",
    "rouge = Rouge()\n",
    "\n",
    "# The candidate is passed first and the reference second; with avg=True the\n",
    "# result is indexed directly by metric name below.\n",
    "scores = rouge.get_scores(candidate, reference, avg=True)\n",
    "\n",
    "# Report only the unigram entry: recall, precision, then F1.\n",
    "rouge_1_scores = scores['rouge-1']\n",
    "for label, key in ((\"Recall\", \"r\"), (\"Precision\", \"p\"), (\"F1\", \"f\")):\n",
    "    print(f\"ROUGE-1 {label}: {rouge_1_scores[key]:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d89def98-71e9-4ee4-96ac-017dd9bf2a28",
   "metadata": {},
   "source": [
    "### TorchMetrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "01198d15-c16c-4aa6-a8ed-65705ab743cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ROUGE-1 Recall: 0.67\n",
      "ROUGE-1 Precision: 1.00\n",
      "ROUGE-1 F1: 0.80\n"
     ]
    }
   ],
   "source": [
    "from torchmetrics.text import ROUGEScore\n",
    "\n",
    "# Select the unigram variant through `rouge_keys`; `n_gram` is a BLEUScore\n",
    "# argument and is not part of ROUGEScore's documented signature.\n",
    "rouge = ROUGEScore(rouge_keys=\"rouge1\")\n",
    "\n",
    "# Calculate ROUGE scores for the single candidate/reference pair.\n",
    "rouge_score = rouge(target=[reference], preds=[candidate])\n",
    "\n",
    "print(f\"ROUGE-1 Recall: {rouge_score['rouge1_recall']:.2f}\")\n",
    "print(f\"ROUGE-1 Precision: {rouge_score['rouge1_precision']:.2f}\")\n",
    "print(f\"ROUGE-1 F1: {rouge_score['rouge1_fmeasure']:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fe6a608-168a-4f7c-87a0-d82b9eac21ac",
   "metadata": {},
   "source": [
    "### From Scratch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "66075150-bb85-465f-945e-3cc008de7efd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ROUGE-1 Recall: 0.67\n",
      "ROUGE-1 Precision: 1.00\n",
      "ROUGE-1 F1: 0.80\n"
     ]
    }
   ],
   "source": [
    "from collections import Counter\n",
    "\n",
    "def tokenize(sentence):\n",
    "    \"\"\"Lowercase ``sentence`` and split it on whitespace into a list of tokens.\"\"\"\n",
    "    lowered = sentence.lower()\n",
    "    return lowered.split()\n",
    "\n",
    "def ngrams(tokens, n):\n",
    "    \"\"\"Return every contiguous run of ``n`` tokens as a tuple, in order of appearance.\"\"\"\n",
    "    window_count = len(tokens) - n + 1\n",
    "    return [tuple(tokens[start:start + n]) for start in range(window_count)]\n",
    "\n",
    "def rouge_1(candidate, reference):\n",
    "    \"\"\"Compute ROUGE-1 (unigram overlap) between a candidate and a reference.\n",
    "\n",
    "    Both strings are lowercased and whitespace-tokenized. The overlap is the\n",
    "    clipped unigram match count: each token counts at most as often as it\n",
    "    appears in both texts.\n",
    "\n",
    "    Args:\n",
    "        candidate: Text being evaluated.\n",
    "        reference: Text the candidate is compared against.\n",
    "\n",
    "    Returns:\n",
    "        A ``(precision, recall, f1_score)`` tuple of floats. All three are\n",
    "        0.0 when either text contains no tokens.\n",
    "    \"\"\"\n",
    "    candidate_1grams = Counter(ngrams(tokenize(candidate), 1))\n",
    "    reference_1grams = Counter(ngrams(tokenize(reference), 1))\n",
    "\n",
    "    candidate_count = sum(candidate_1grams.values())\n",
    "    reference_count = sum(reference_1grams.values())\n",
    "\n",
    "    # Keep the 3-tuple shape for empty input so callers can always unpack it.\n",
    "    if candidate_count == 0 or reference_count == 0:\n",
    "        return 0.0, 0.0, 0.0\n",
    "\n",
    "    # Counter intersection keeps the minimum count per unigram (clipped matches).\n",
    "    overlap_count = sum((candidate_1grams & reference_1grams).values())\n",
    "\n",
    "    precision = overlap_count / candidate_count\n",
    "    recall = overlap_count / reference_count\n",
    "    if overlap_count == 0:\n",
    "        # No shared unigrams: precision + recall is 0, so do not divide by it.\n",
    "        f1_score = 0.0\n",
    "    else:\n",
    "        f1_score = 2 * precision * recall / (precision + recall)\n",
    "\n",
    "    return precision, recall, f1_score\n",
    "\n",
    "\n",
    "precision, recall, f1_score = rouge_1(candidate, reference)\n",
    "\n",
    "# Print recall, then precision, then F1, each rounded to two decimals.\n",
    "for label, value in ((\"Recall\", recall), (\"Precision\", precision), (\"F1\", f1_score)):\n",
    "    print(f\"ROUGE-1 {label}: {value:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59d70b91-bd09-43a8-839a-82174973b546",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29882a0b-2b34-4b6b-b371-835ffe7f3cbe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
